"""
Phase 9: Systematic subsample investigation of Z×KAOPEN interactions.

The 140-country followup found Z×KAOPEN interactions lose significance in the
expanded sample. This script cuts the data by financial openness, OECD status,
time period, demographic stage, income×openness cross-tabs, and influential
observation exclusion to determine where the interactions survive.

Output: followup/output/tables/table_interaction_subsamples.md
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
from scipy import stats

FOLLOWUP_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup")
sys.path.insert(0, str(FOLLOWUP_DIR))

from src.model import PanelGLS
from src.macro import (
    EBA_COUNTRIES, SSA_COUNTRIES, EU_EXPANSION, EXPANSION_TIER1,
    filter_eba_sample,
)

OUTPUT_DIR = FOLLOWUP_DIR / "output" / "tables"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # created at import time so tables can always be written
DATA_DIR = FOLLOWUP_DIR / "data" / "processed"

# Demographic factors and their products with KAOPEN (the interaction terms
# under test). The interaction names are derived so both lists stay aligned.
DEMO_VARS = ['Z_1', 'Z_2', 'Z_3']
INTERACTION_VARS = [f'{z}_x_kaopen' for z in DEMO_VARS]

# Extended model controls (matches phase4e)
CONTROLS = [
    'fiscal_bal_gdp', 'kaopen', 'expected_growth',
    'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp',
]
RATE_VAR = ['log_lending_rate']

# Column order of the unrestricted design matrix; coefficients are looked up
# by their position in this list.
ALL_REGRESSORS = [*DEMO_VARS, *CONTROLS, *RATE_VAR, *INTERACTION_VARS]

# OECD members as of 2024 (38 countries, ISO3 codes)
OECD_COUNTRIES = set("""
    AUS AUT BEL CAN CHL COL CRI CZE DNK
    EST FIN FRA DEU GRC HUN ISL IRL ISR
    ITA JPN KOR LVA LTU LUX MEX NLD NZL
    NOR POL PRT SVK SVN ESP SWE CHE TUR
    GBR USA
""".split())

# DFBETA-flagged countries from Probe 3 (simple spec, 14 flagged)
DFBETA_FLAGGED = set("""
    GNQ KWT LSO SDN ETH BTN SAU NGA
    DZA COM CPV MNG AZE LAO
""".split())


def load_panel():
    """Read the processed panel and restrict it to the estimation window.

    Keeps years up to 2024, applies ``filter_eba_sample`` with the extended
    and expansion options enabled, then drops years before 1986. Prints a
    one-line size summary and returns the filtered DataFrame.
    """
    panel = pd.read_csv(DATA_DIR / "full_panel.csv")
    panel = panel.loc[panel['year'] <= 2024].copy()
    panel = filter_eba_sample(panel, extended=True, expansion=True)
    panel = panel.loc[panel['year'] >= 1986].copy()

    n_rows = len(panel)
    n_iso = panel['iso3'].nunique()
    first_year, last_year = panel['year'].min(), panel['year'].max()
    print(f"Working panel: {n_rows:,} obs, {n_iso} countries, "
          f"years {first_year}-{last_year}")
    return panel


def stars(p):
    """Return the conventional significance marker for p-value ``p``.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise ''
    (NaN also yields '' because every comparison with NaN is False).
    """
    for threshold, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < threshold:
            return marker
    return ''


def _interaction_f_test(model, model_r):
    """Joint test of H0: all Z×KAOPEN interaction coefficients are zero.

    Parameters
    ----------
    model : fitted PanelGLS on ALL_REGRESSORS (unrestricted).
    model_r : fitted PanelGLS on the same rows without INTERACTION_VARS.

    Returns
    -------
    (F_stat, F_p, method). ``method`` is 'r2' for the R²-difference F-test,
    'wald_diag' for the fallback built from individual t-statistics, or
    'undefined' when no residual degrees of freedom remain (F_stat and F_p
    are then NaN).
    """
    q = len(INTERACTION_VARS)
    k_u = len(ALL_REGRESSORS)
    # NOTE(review): residual df counts only the explicit regressors plus one;
    # if PanelGLS absorbs country/time effects internally, the true df is
    # smaller and this p-value is somewhat optimistic — confirm.
    df2 = model.n_obs - k_u - 1
    if df2 <= 0:
        return np.nan, np.nan, 'undefined'

    delta_r2 = model.r_squared - model_r.r_squared
    if delta_r2 > 0 and model.r_squared < 1:
        # R²-based F-test (same sample ensures valid comparison)
        f_stat = (delta_r2 / q) / ((1 - model.r_squared) / df2)
        method = 'r2'
    else:
        # The restricted fit can match or beat the unrestricted R² (presumably
        # because the GLS weights differ between the two fits), which makes the
        # R² comparison meaningless. Fall back to the mean squared t-statistic;
        # this ignores covariance between the interaction estimates, so it only
        # approximates a proper Wald test.
        idx = [ALL_REGRESSORS.index(iv) for iv in INTERACTION_VARS]
        t_sq = (model.beta[idx] / model.se[idx]) ** 2
        f_stat = np.sum(t_sq) / q
        method = 'wald_diag'

    # Upper tail via the survival function: 1 - cdf rounds to 0 for tiny p.
    f_p = stats.f.sf(f_stat, q, df2)
    return f_stat, f_p, method


def run_extended_model(df, label=''):
    """Run the extended model with Z×KAOPEN interactions on a subsample.

    Parameters
    ----------
    df : panel subset containing ca_gdp, ALL_REGRESSORS, iso3 and year.
    label : name used in console output and in the results table.

    Returns
    -------
    dict with label, N, Countries, R2, the Z₁ coefficient/p-value, each
    Z×KAOPEN coefficient/p-value, the joint F statistic and p-value
    (F_stat, F_p) and the F-test variant used (F_method); or None when the
    complete-case sample has fewer than 50 observations or 5 countries.
    """
    est = df.dropna(subset=['ca_gdp'] + ALL_REGRESSORS + ['iso3', 'year']).copy()
    n_obs = len(est)
    n_countries = est['iso3'].nunique()

    if n_obs < 50 or n_countries < 5:
        print(f"  {label}: skip (N={n_obs}, countries={n_countries})")
        return None

    model = PanelGLS()
    model.fit(est['ca_gdp'].values, est[ALL_REGRESSORS].values,
              est['iso3'].values, est['year'].values)
    model.feature_names = ALL_REGRESSORS

    result = {
        'label': label,
        'N': model.n_obs,
        'Countries': model.n_countries,
        'R2': model.r_squared,
    }

    # Z₁ coefficient
    idx_z1 = ALL_REGRESSORS.index('Z_1')
    result['Z1_coef'] = model.beta[idx_z1]
    result['Z1_p'] = model.pvalues[idx_z1]

    # Z×KAOPEN interaction coefficients, keyed Z1xK_*, Z2xK_*, Z3xK_*
    for ivar in INTERACTION_VARS:
        idx = ALL_REGRESSORS.index(ivar)
        short = ivar.replace('_x_kaopen', '').replace('Z_', 'Z')
        result[f'{short}xK_coef'] = model.beta[idx]
        result[f'{short}xK_p'] = model.pvalues[idx]

    # Restricted model for the joint test: same rows and controls, no interactions.
    model_r = PanelGLS()
    model_r.fit(est['ca_gdp'].values, est[DEMO_VARS + CONTROLS + RATE_VAR].values,
                est['iso3'].values, est['year'].values)

    F_stat, F_p, method = _interaction_f_test(model, model_r)
    result['F_stat'] = F_stat
    result['F_p'] = F_p
    result['F_method'] = method

    sig = stars(F_p)
    print(f"  {label}: N={n_obs}, C={n_countries}, R²={model.r_squared:.3f}, "
          f"Z₁={result['Z1_coef']:.1f}({stars(result['Z1_p'])}), "
          f"Z₁×K={result['Z1xK_coef']:.1f}({stars(result['Z1xK_p'])}), "
          f"Joint F={F_stat:.2f} p={F_p:.3f}{sig}")

    return result


def _tercile(values, labels, cuts=(0.33, 0.67)):
    """Bucket a country-level series into three labelled groups.

    Cut points are the ``cuts`` quantiles of the non-missing values; missing
    values stay NaN. When nothing is observed every entry is NaN, because
    pd.cut cannot use NaN bin edges.
    NOTE(review): 0.33/0.67 approximate exact terciles (1/3, 2/3). pd.cut
    raises if the two cut points coincide (heavy ties), and values equal to
    the upper cut fall in the middle group.
    """
    if values.notna().sum() == 0:
        return pd.Series(pd.Categorical([np.nan] * len(values), categories=labels),
                         index=values.index)
    lo = values.quantile(cuts[0])
    hi = values.quantile(cuts[1])
    return pd.cut(values, bins=[-np.inf, lo, hi, np.inf], labels=labels)


def classify_countries(df, oecd_countries=None):
    """Add country-level classifications for KAOPEN terciles, income, demo stage and OECD.

    Parameters
    ----------
    df : panel with columns iso3, kaopen, old_dep and gdp_pc_ppp.
    oecd_countries : optional collection of ISO3 codes treated as OECD
        members; defaults to OECD_COUNTRIES.

    Returns
    -------
    DataFrame indexed by iso3 with median_kaopen, median_oadr, median_gdp_pc,
    the categoricals kaopen_tercile / income_tercile / demo_stage (NaN where
    the underlying median is missing) and a boolean is_oecd for every country.
    """
    if oecd_countries is None:
        oecd_countries = OECD_COUNTRIES

    # Country-level medians. Missing data is handled per classification rather
    # than by dropping the whole row, so a country lacking e.g. GDP data still
    # gets its KAOPEN tercile and a definite OECD flag. Each tercile's cut
    # points are therefore computed over the countries observed in that column.
    country_stats = df.groupby('iso3').agg(
        median_kaopen=('kaopen', 'median'),
        median_oadr=('old_dep', 'median'),
        median_gdp_pc=('gdp_pc_ppp', 'median'),
    )

    country_stats['kaopen_tercile'] = _tercile(
        country_stats['median_kaopen'], ['low', 'medium', 'high'])
    country_stats['income_tercile'] = _tercile(
        country_stats['median_gdp_pc'], ['low', 'middle', 'high'])
    # Demographic stage from old-age dependency ratio (OADR) terciles
    country_stats['demo_stage'] = _tercile(
        country_stats['median_oadr'], ['young', 'transitioning', 'aging'])
    country_stats['is_oecd'] = country_stats.index.isin(oecd_countries)

    print(f"\nCountry classifications (N={len(country_stats)}):")
    print(f"  KAOPEN terciles: {country_stats['kaopen_tercile'].value_counts().to_dict()}")
    print(f"  Income terciles: {country_stats['income_tercile'].value_counts().to_dict()}")
    print(f"  Demo stage: {country_stats['demo_stage'].value_counts().to_dict()}")
    print(f"  OECD: {country_stats['is_oecd'].sum()} / {len(country_stats)}")

    return country_stats


def build_markdown_table(results):
    """Render subsample results as markdown lines.

    Emits a title, the model description, the table header and one row per
    result; ``None`` entries (skipped subsamples) are ignored. Each
    coefficient is starred by its own p-value and the joint F by F_p.
    """
    table = [
        '# Interaction Subsample Analysis: Z×KAOPEN on 140-Country Panel',
        '',
        'Extended model: CA/GDP ~ Z₁ + Z₂ + Z₃ + EBA controls + log_lending_rate + Z×KAOPEN',
        '',
        '| Subsample | Z₁ | Z₁ p | Z₁×K | Z₁×K p | Z₂×K | Z₂×K p | Z₃×K | Z₃×K p | Joint F | F p | N | C | R² |',
        '|-----------|-----|------|------|--------|------|--------|------|--------|---------|-----|---|---|-----|',
    ]

    for res in (entry for entry in results if entry is not None):
        cells = [
            f"{res['label']}",
            f"{res['Z1_coef']:.1f}{stars(res['Z1_p'])}",
            f"{res['Z1_p']:.3f}",
        ]
        for term in ('Z1xK', 'Z2xK', 'Z3xK'):
            coef, pval = res[f'{term}_coef'], res[f'{term}_p']
            cells.append(f"{coef:.1f}{stars(pval)}")
            cells.append(f"{pval:.3f}")
        cells.extend([
            f"{res['F_stat']:.2f}{stars(res['F_p'])}",
            f"{res['F_p']:.3f}",
            f"{res['N']}",
            f"{res['Countries']}",
            f"{res['R2']:.3f}",
        ])
        table.append('| ' + ' | '.join(cells) + ' |')

    return table


def _banner(title):
    """Print ``title`` framed by 70-character rules, preceded by a blank line."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _income_openness_cells(country_stats):
    """Return (label, iso3 set) pairs for the 2×2 income × openness matrix.

    Countries are split at the cross-country median of their median GDP per
    capita and of their median KAOPEN. A country whose median is missing in
    either dimension lands in no cell, since comparisons with NaN are False.
    """
    kaopen_median = country_stats['median_kaopen'].median()
    gdp_median = country_stats['median_gdp_pc'].median()

    high_inc_isos = set(country_stats[country_stats['median_gdp_pc'] >= gdp_median].index)
    low_inc_isos = set(country_stats[country_stats['median_gdp_pc'] < gdp_median].index)
    high_k_isos = set(country_stats[country_stats['median_kaopen'] >= kaopen_median].index)
    low_k_isos = set(country_stats[country_stats['median_kaopen'] < kaopen_median].index)

    return [
        ('High-inc + High-K', high_inc_isos & high_k_isos),
        ('High-inc + Low-K', high_inc_isos & low_k_isos),
        ('Low/mid-inc + High-K', low_inc_isos & high_k_isos),
        ('Low/mid-inc + Low-K', low_inc_isos & low_k_isos),
    ]


def main():
    """Run every Z×KAOPEN subsample regression and write the markdown summary.

    Cuts: full sample, KAOPEN terciles, OECD vs non-OECD, pre/post-GFC,
    OADR-based demographic stage, a 2×2 income × openness matrix, and the
    sample without DFBETA-flagged countries. The table is written to
    OUTPUT_DIR / "table_interaction_subsamples.md" and a summary is printed.
    """
    print("=" * 70)
    print("PHASE 9: Systematic Subsample Investigation of Z×KAOPEN Interactions")
    print("=" * 70)

    df = load_panel()
    country_stats = classify_countries(df)

    # Merge classifications into panel
    df = df.merge(country_stats[['kaopen_tercile', 'income_tercile', 'demo_stage', 'is_oecd']],
                  left_on='iso3', right_index=True, how='left')
    total_countries = df['iso3'].nunique()

    results = []

    # Full sample (benchmark); the label reports the actual country count
    # instead of a hard-coded figure that can drift from the data.
    _banner("FULL SAMPLE (benchmark)")
    results.append(run_extended_model(df, f'Full sample ({total_countries})'))

    # Cut 1: KAOPEN terciles
    _banner("CUT 1: KAOPEN Terciles")
    for tercile in ['low', 'medium', 'high']:
        sub = df[df['kaopen_tercile'] == tercile].copy()
        results.append(run_extended_model(sub, f'KAOPEN: {tercile}'))

    # Cut 2: OECD membership is read from iso3 directly, so countries whose
    # merged is_oecd is NaN (no row in country_stats) still fall into exactly
    # one of the two groups.
    _banner("CUT 2: OECD vs non-OECD")
    in_oecd = df['iso3'].isin(OECD_COUNTRIES)
    results.append(run_extended_model(df[in_oecd].copy(), 'OECD'))
    results.append(run_extended_model(df[~in_oecd].copy(), 'Non-OECD'))

    # Cut 3: Pre-GFC vs post-GFC (2008 itself belongs to neither window)
    _banner("CUT 3: Pre-GFC vs Post-GFC")
    results.append(run_extended_model(df[df['year'] <= 2007].copy(), 'Pre-GFC (1986-2007)'))
    results.append(run_extended_model(df[df['year'] >= 2009].copy(), 'Post-GFC (2009-2024)'))

    # Cut 4: Demographic stage
    _banner("CUT 4: Demographic Stage (OADR terciles)")
    for stage in ['young', 'transitioning', 'aging']:
        sub = df[df['demo_stage'] == stage].copy()
        results.append(run_extended_model(sub, f'Demo: {stage}'))

    # Cut 5: 2×2 income × openness (binary splits at the median)
    _banner("CUT 5: Income × Openness 2×2 Matrix")
    for label, isos in _income_openness_cells(country_stats):
        sub = df[df['iso3'].isin(isos)].copy()
        results.append(run_extended_model(sub, label))

    # Cut 6: Excluding DFBETA-flagged countries
    _banner("CUT 6: Excluding DFBETA-Flagged Countries")
    df_excl = df[~df['iso3'].isin(DFBETA_FLAGGED)].copy()
    results.append(run_extended_model(df_excl, f'Excl {len(DFBETA_FLAGGED)} DFBETA-flagged'))

    # Build output table
    valid_results = [r for r in results if r is not None]
    lines = build_markdown_table(valid_results)

    # Add interpretive summary
    lines.extend(['', '*,**,*** denote significance at 10%, 5%, 1%.', ''])
    lines.append('## Interpretation')
    lines.append('')

    # Auto-detect where interactions survive
    sig_subs = [r['label'] for r in valid_results if r['F_p'] < 0.10]
    insig_subs = [r['label'] for r in valid_results if r['F_p'] >= 0.10]

    if sig_subs:
        lines.append(f'**Interactions significant (Joint F p<0.10):** {", ".join(sig_subs)}')
    else:
        lines.append('**Interactions significant (Joint F p<0.10):** None')
    lines.append('')
    if insig_subs:
        lines.append(f'**Interactions insignificant (Joint F p≥0.10):** {", ".join(insig_subs)}')
    lines.append('')

    # Country/obs accounting
    lines.append(f'Total countries in expanded sample: {total_countries}')

    # Save; the table contains non-ASCII symbols (×, ₁, ², ≥), so the encoding
    # is pinned instead of relying on the platform default.
    outpath = OUTPUT_DIR / "table_interaction_subsamples.md"
    with open(outpath, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')
    print(f"\nSaved: {outpath}")

    # Print summary
    _banner("SUMMARY")
    print(f"Ran {len(valid_results)} subsample regressions")
    print(f"Interactions significant in: {len(sig_subs)}/{len(valid_results)} subsamples")
    if sig_subs:
        print(f"  Significant: {sig_subs}")
    if insig_subs:
        print(f"  Insignificant: {insig_subs}")


# Script entry point: run the subsample investigation only when executed directly.
if __name__ == "__main__":
    main()
